from sklearn.datasets import fetch_openml
from sklearn.utils import shuffle
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import normalize
import numpy as np
import pandas as pd 

import torch
import torchvision
from torchvision import datasets, transforms
from sklearn.cluster import KMeans
from collections import defaultdict


class load_mnist_1d:
    """MNIST contextual bandit with a disjoint (block one-hot) arm encoding.

    Each round draws one training image. Arm ``a`` sees the flattened image
    in block ``a`` of a ``784 * n_arm`` vector (zeros elsewhere), and only the
    arm whose index equals the digit label pays reward 1.
    """

    def __init__(self):
        # Fetch data
        batch_size = 1
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST('./data', train=True, download=True,
                                  transform=transform)
        # Keep the loader so a fresh (reshuffled) pass can start once this one is used up.
        self.train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size,
                                                        shuffle=True, num_workers=2)
        self.dataiter = iter(self.train_loader)
        self.n_arm = 10
        self.dim = 7840

    def _next_batch(self):
        """Return the next ``(x, y)`` batch, restarting the loader when it is exhausted."""
        try:
            # DataLoader iterators implement __next__ only; there is no .next() method.
            return next(self.dataiter)
        except StopIteration:
            self.dataiter = iter(self.train_loader)
            return next(self.dataiter)

    def step(self):
        """Draw one round.

        Returns:
            (X_n, rwd): ``X_n`` has shape ``(n_arm, 784 * n_arm)``; row ``a``
            holds the flattened image in columns ``784*a : 784*(a+1)``.
            ``rwd`` is one-hot at the image's label.
        """
        x, y = self._next_batch()
        d = x.numpy()[0].reshape(784)
        target = y.item()
        block = d.size
        X_n = np.zeros((self.n_arm, block * self.n_arm))
        for arm in range(self.n_arm):
            X_n[arm, arm * block:(arm + 1) * block] = d
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return X_n, rwd


class load_mnist_adv:
    """MNIST bandit whose true class can be biased toward a caller-chosen digit.

    One DataLoader is built per digit. ``step(i)`` draws the class (favouring
    ``i`` with probability 0.7, or uniformly when ``i == -1``), takes the next
    image of that class, and embeds it block-wise per arm like ``load_mnist_1d``.
    """

    def __init__(self):
        # Fetch data
        batch_size = 1
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset = datasets.MNIST('./data', train=True, download=True,
                                 transform=transform)
        self.trainloaders = []
        self.dataiter = []
        for digit in range(10):
            # `targets` replaces the deprecated `train_labels` alias.
            indices = torch.nonzero(dataset.targets == digit).flatten().tolist()
            subset = torch.utils.data.Subset(dataset, indices)
            loader = torch.utils.data.DataLoader(subset, batch_size=batch_size,
                                                 shuffle=True, num_workers=2)
            self.trainloaders.append(loader)
            self.dataiter.append(iter(loader))

        self.n_arm = 10
        self.dim = 7840

    def _next_from_class(self, j):
        """Return the next ``(x, y)`` batch of class ``j``, restarting its loader when exhausted."""
        try:
            # DataLoader iterators implement __next__ only; there is no .next() method.
            return next(self.dataiter[j])
        except StopIteration:
            # A favoured class (p=0.7) runs through its ~6k images quickly; start a new pass.
            self.dataiter[j] = iter(self.trainloaders[j])
            return next(self.dataiter[j])

    def step(self, i):
        """Draw one round.

        Args:
            i: class to favour (drawn with probability 0.7, the other classes
               share 0.3 evenly), or -1 for uniform sampling.

        Returns:
            (X_n, rwd, j): block-embedded contexts of shape ``(n_arm, 784 * n_arm)``,
            reward vector one-hot at the image's label, and the sampled class ``j``.
        """
        if i != -1:
            prob = np.full(self.n_arm, 0.3 / (self.n_arm - 1))
            prob[i] = 0.7
        else:
            prob = np.full(self.n_arm, 1.0 / self.n_arm)

        j = np.random.choice(np.arange(0, self.n_arm), p=prob)

        x, y = self._next_from_class(j)
        d = x.numpy()[0].reshape(784)
        target = y.item()
        block = d.size
        X_n = np.zeros((self.n_arm, block * self.n_arm))
        for arm in range(self.n_arm):
            X_n[arm, arm * block:(arm + 1) * block] = d
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return X_n, rwd, j

    

class load_yelp:
    """Yelp user/item bandit built from pre-extracted feature files.

    Each round offers 10 (user, item) pairs: 9 drawn without replacement from
    the rows labelled 1 and one drawn from the remaining rows. That single
    pair is placed at a random arm, which is the only arm with reward 1.
    A context is the user feature vector concatenated with the item one.
    """

    def __init__(self):
        # Fetch data
        self.m = np.load("./data/yelp_2000users_10000items_entry.npy")
        self.U = np.load("./data/yelp_2000users_10000items_features.npy")
        self.I = np.load("./data/yelp_10000items_2000users_features.npy")
        self.n_arm = 10
        self.dim = 20
        # Columns 0 and 1 of each entry row index U and I; column 2 is the label.
        label_one = [(row[0], row[1]) for row in self.m if row[2] == 1]
        others = [(row[0], row[1]) for row in self.m if row[2] != 1]

        self.p_d = len(label_one)
        self.n_d = len(others)
        print(self.p_d, self.n_d)
        self.pos_index = np.array(label_one)
        self.neg_index = np.array(others)

    def step(self):
        """Return ``(contexts, rewards)`` for one round; see the class docstring."""
        arm = np.random.choice(range(10))
        chosen = np.random.choice(range(self.p_d), 9, replace=False)
        pos = self.pos_index[chosen]
        neg = self.neg_index[np.random.choice(range(self.n_d), replace=False)]
        # Splice the single non-label-1 pair in at position `arm`.
        pairs = np.concatenate((pos[:arm], [neg], pos[arm:]), axis=0)
        contexts = np.array([np.concatenate((self.U[user], self.I[item]))
                             for user, item in pairs])
        rwd = np.zeros(self.n_arm)
        rwd[arm] = 1
        return contexts, rwd

    

class load_movielen:
    """MovieLens user/item bandit built from pre-extracted feature files.

    Each round offers 10 (user, item) pairs: 9 drawn without replacement from
    the rows labelled 1 and one drawn from the remaining rows. That single
    pair is placed at a random arm, which is the only arm with reward 1.
    A context is the user feature vector concatenated with the item one.
    """

    def __init__(self):
        # Fetch data
        self.m = np.load("./data/movie_2000users_10000items_entry.npy")
        self.U = np.load("./data/movie_2000users_10000items_features.npy")
        self.I = np.load("./data/movie_10000items_2000users_features.npy")
        self.n_arm = 10
        self.dim = 20
        # Columns 0 and 1 of each entry row index U and I; column 2 is the label.
        label_one = [(row[0], row[1]) for row in self.m if row[2] == 1]
        others = [(row[0], row[1]) for row in self.m if row[2] != 1]

        self.p_d = len(label_one)
        self.n_d = len(others)
        print(self.p_d, self.n_d)
        self.pos_index = np.array(label_one)
        self.neg_index = np.array(others)

    def step(self):
        """Return ``(contexts, rewards)`` for one round; see the class docstring."""
        arm = np.random.choice(range(10))
        chosen = np.random.choice(range(self.p_d), 9, replace=False)
        pos = self.pos_index[chosen]
        neg = self.neg_index[np.random.choice(range(self.n_d), replace=False)]
        # Splice the single non-label-1 pair in at position `arm`.
        pairs = np.concatenate((pos[:arm], [neg], pos[arm:]), axis=0)
        contexts = np.array([np.concatenate((self.U[user], self.I[item]))
                             for user, item in pairs])
        rwd = np.zeros(self.n_arm)
        rwd[arm] = 1
        return contexts, rwd

class load_notmnist_mnist_2:
    """MNIST bandit where arm ``a`` sees the image shifted by ``a`` positions.

    Each context row has length ``act_dim + n_arm - 1``; row ``a`` holds the
    flattened image at columns ``a : a + act_dim``. Only the arm equal to the
    digit label pays reward 1.
    """

    def __init__(self):
        #  mnist
        batch_size = 1
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST('./data', train=True, download=True,
                                  transform=transform)
        self.train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size,
                                                        shuffle=True, num_workers=2)
        self.dataiter = iter(self.train_loader)
        # The sizes come from the loaded dataset: the class holds no X / y_arm arrays.
        self.n_arm = int(dataset1.targets.max()) + 1
        self.act_dim = int(np.prod(dataset1.data.shape[1:]))
        # Shifting by up to n_arm - 1 positions needs that many extra columns.
        self.dim = self.act_dim + self.n_arm - 1

    def _next_batch(self):
        """Return the next ``(x, y)`` batch, restarting the loader when it is exhausted."""
        try:
            # DataLoader iterators implement __next__ only; there is no .next() method.
            return next(self.dataiter)
        except StopIteration:
            self.dataiter = iter(self.train_loader)
            return next(self.dataiter)

    def step(self):
        """Return ``(X, rwd)``: shifted contexts ``(n_arm, dim)`` and a one-hot reward."""
        x, y = self._next_batch()
        d = x.numpy()[0].reshape(self.act_dim)
        target = y.item()
        X = np.zeros((self.n_arm, self.dim))
        for a in range(self.n_arm):
            X[a, a:a + self.act_dim] = d
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return X, rwd


class Bandit_multi:
    """Contextual-bandit environment built from an OpenML classification dataset.

    Every class label is an arm. ``step`` samples a class ``j``, takes the next
    stored sample of that class (cycling through the class's rows in order),
    and embeds it once per arm: row ``a`` is the feature vector shifted right
    by ``a`` positions inside a length ``act_dim + n_arm`` vector. Only arm
    ``j`` pays reward 1.
    """

    # name -> (OpenML dataset name, version, one-hot encode X?, remap labels to 1..K?)
    # NOTE(review): 'Plants' fetches the 'nursery' dataset exactly like 'nursery'
    # does; this looks like a copy-paste slip -- confirm the intended OpenML dataset.
    _DATASETS = {
        'covertype': ('covertype', 3, True, False),
        'MagicTelescope': ('MagicTelescope', 1, False, True),
        'shuttle': ('shuttle', 1, False, False),
        'adult': ('adult', 2, True, True),
        'mushroom': ('mushroom', 1, True, True),
        'fashion': ('Fashion-MNIST', 1, True, False),
        'nursery': ('nursery', 1, True, True),
        'Plants': ('nursery', 1, True, True),
        'leaf': ('leaf', 1, True, True),
        'eucalyptus': ('eucalyptus', 1, True, True),
    }

    def __init__(self, name):
        # Fetch data
        X, y = self._load(name)
        # Shuffle data
        self.X, self.y = shuffle(X, y)
        # integer class index per row
        self.y_arm = np.array(self.y.values).astype(np.int64)
        # Fashion-MNIST labels already start at 0; every other dataset here is 1-based.
        if name != 'fashion':
            self.y_arm = self.y_arm - 1
        # cursor counts rounds (reset after len(X)); it does not select samples.
        self.cursor = 0
        self.size = self.y.shape[0]
        self.n_arm = int(np.max(self.y_arm) + 1)
        print(np.unique(self.y_arm), self.n_arm)
        print(self.X.shape[1])
        self.dim = self.X.shape[1] + self.n_arm
        self.act_dim = self.X.shape[1]
        self.num_user = np.max(self.y_arm) + 1
        # Rows of X grouped by class, plus a read position per class.
        self.input_ = [self.X[self.y_arm == k] for k in range(self.n_arm)]
        self._next_row = [0] * self.n_arm

    @classmethod
    def _load(cls, name):
        """Fetch ``name`` from OpenML; return row-normalized features and (possibly remapped) labels.

        Raises:
            RuntimeError: if ``name`` is not a supported dataset.
        """
        if name not in cls._DATASETS:
            raise RuntimeError('Dataset does not exist')
        openml_name, version, one_hot, remap = cls._DATASETS[name]
        X, y = fetch_openml(openml_name, version=version, return_X_y=True)
        if one_hot:
            X = pd.get_dummies(X)
        if remap:
            # Sorted so the label -> arm mapping does not depend on set iteration
            # order, which for strings changes with the per-process hash seed.
            labels = sorted(set(y.values), key=str)
            y = y.map({value: k + 1 for k, value in enumerate(labels)})
        # avoid nan, set nan as -1
        X[np.isnan(X)] = -1
        X = normalize(X)
        return X, y

    def step(self, i):
        """Draw one round.

        Args:
            i: arm to favour when sampling the true class, or -1 for uniform
               sampling. A favoured arm gets probability 0.6 (two arms) or 0.4
               (more arms); the remaining arms share the rest evenly.

        Returns:
            (X_n, rwd, j): contexts of shape ``(n_arm, act_dim + n_arm)``,
            reward vector one-hot at ``j``, and the sampled class ``j``.
        """
        if self.cursor > (len(self.X) - 1):
            self.cursor = 0

        if i != -1:
            p = 0.6 if self.n_arm == 2 else 0.4
            prob = np.full(self.n_arm, (1 - p) / (self.n_arm - 1))
            prob[i] = p
        else:
            prob = np.full(self.n_arm, 1 / self.n_arm)

        j = np.random.choice(np.arange(0, self.n_arm), p=prob)
        # Walk through class j's rows in order, wrapping around, so consecutive
        # rounds see different contexts (np.roll returns a copy and would not do this).
        pool = self.input_[j]
        x = pool[self._next_row[j] % len(pool)]
        self._next_row[j] += 1

        width = len(x)
        X_n = np.zeros((self.n_arm, width + self.n_arm))
        for arm in range(self.n_arm):
            X_n[arm, arm:arm + width] = x
        rwd = np.zeros(self.n_arm)
        rwd[j] = 1
        self.cursor += 1
        return X_n, rwd, j


class load_emnist_letter_1d:
    """EMNIST 'letters' bandit: 26 arms, each seeing the image at its own offset.

    Row ``a`` of a context holds the flattened 28x28 image starting at column
    ``num_zeros * a``, padded with zeros to length
    ``784 + num_zeros * (num_class - 1)``. Only the arm for the letter label
    pays reward 1.
    """

    def __init__(self, is_shuffle=True):
        """Build the EMNIST letters loader.

        Args:
            is_shuffle: whether the DataLoader shuffles the training set.
        """
        # Fetch data
        batch_size = 1
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        trainset = torchvision.datasets.EMNIST(root='./data', split="letters", train=True,
                                               download=True, transform=transform)
        self.trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                       shuffle=is_shuffle, num_workers=2)
        self.dataiter = iter(self.trainloader)

        self.n_arm = 26
        self.num_zeros = 10
        self.num_class = 26
        self.num_user = 26
        self.dim = 28*28 + self.num_zeros*(self.num_class - 1)

    def _next_batch(self):
        """Return the next ``(x, y)`` batch, restarting the loader when it is exhausted."""
        try:
            # DataLoader iterators implement __next__ only; there is no .next() method.
            return next(self.dataiter)
        except StopIteration:
            self.dataiter = iter(self.trainloader)
            return next(self.dataiter)

    def step(self):
        """Return ``(X_n, rwd)``: offset contexts ``(n_arm, dim)`` and a one-hot reward."""
        x, y = self._next_batch()
        d = x.numpy()[0][0].reshape(28*28)
        # Labels are shifted down by one to index arms (assumes labels run 1..26).
        target = y.item() - 1
        width = d.size + self.num_zeros * (self.num_class - 1)
        X_n = np.zeros((self.n_arm, width))
        for arm in range(self.n_arm):
            start = self.num_zeros * arm
            X_n[arm, start:start + d.size] = d
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return X_n, rwd

class synthetic:
    """Synthetic bandit with random unit-norm contexts and a nonlinear reward.

    ``name`` selects the reward: 'cos' -> cos(3 * x.a), 'square' -> 10 * (x.a)^2,
    anything else -> x^T (A^T A) x. Each reward gets N(0, 0.1) noise.
    """

    def __init__(self, name):
        self.name = name
        self.n_arm = 4
        self.dim = 20
        # Random parameter vector (dim, 1) and matrix (dim, dim) defining the rewards.
        self.a = np.random.randn(self.dim, 1)
        self.A = np.random.normal(0, 1, (self.dim, self.dim))

    def step(self, t):
        """Sample one round.

        Args:
            t: round index; not used.

        Returns:
            (X_n, rwd, 1): contexts of shape ``(n_arm, dim)`` with unit-norm
            rows, noisy rewards of shape ``(n_arm,)``, and the constant 1.
        """
        x = [0] * self.n_arm
        for k in range(self.n_arm):
            x[k] = np.random.randn(1, self.dim)[0]
            x[k] /= np.linalg.norm(x[k], axis=0)
        # Flatten a so each dot product is a scalar; assigning a size-1 array into
        # rwd[k] is a deprecated array-to-scalar conversion (NumPy >= 1.25).
        a = np.ravel(self.a)
        rwd = np.zeros(self.n_arm)
        if self.name == 'cos':
            for k in range(self.n_arm):
                rwd[k] = np.cos(3 * np.dot(x[k], a)) + np.random.normal(0, 0.1)
        elif self.name == 'square':
            for k in range(self.n_arm):
                rwd[k] = 10 * (np.dot(x[k], a)) ** 2 + np.random.normal(0, 0.1)
        else:
            AtA = np.matmul(np.transpose(self.A), self.A)
            for k in range(self.n_arm):
                rwd[k] = np.dot(x[k], np.matmul(AtA, x[k])) + np.random.normal(0, 0.1)

        X_n = np.array(x)
        return X_n, rwd, 1

    

    

